{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from NCDM import NCDM\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "from initial_dataSet import DataSet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataSet: ASSIST_0910\n"
     ]
    }
   ],
   "source": [
    "# ---------- Basic parameters ----------\n",
    "basedir = '../../'\n",
    "dataSet_list = ('ASSIST_0910', 'ASSIST_2017', 'JUNYI', 'MathEC', 'KDDCUP')\n",
    "save_list = ('a0910/', 'a2017/', 'junyi/', 'math_ec/', 'kddcup/')\n",
    "dataSet_idx = 0  # selects the same position in both dataSet_list and save_list\n",
    "test_ratio = 0.2\n",
    "batch_size = 32\n",
    "\n",
    "data_set_name = dataSet_list[dataSet_idx]\n",
    "epochs = 8\n",
    "# Fall back to the CPU when no CUDA device is available so the notebook still runs.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Seed the global torch and NumPy generators so every run starts from the same random state.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "# ---------- Basic parameters ----------\n",
    "\n",
    "# ---------- Dataset ----------\n",
    "dataSet = DataSet(basedir, data_set_name)\n",
    "Q = dataSet.get_exer_conc_adj()\n",
    "user_n = dataSet.student_num\n",
    "item_n = dataSet.exercise_num\n",
    "knowledge_n = dataSet.concept_num\n",
    "\n",
    "# NOTE(review): machine-specific absolute path; point data_root at the directory\n",
    "# that holds the per-dataset train.csv / test.csv splits on your machine.\n",
    "data_root = 'E:/PY_Project/知识点交互CDM/Experiment/output/'\n",
    "data_dir = data_root + save_list[dataSet_idx]\n",
    "train_data = pd.read_csv(data_dir + 'train.csv')\n",
    "test_data = pd.read_csv(data_dir + 'test.csv')\n",
    "# ---------- Dataset ----------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 4043/4043 [00:19<00:00, 208.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] average loss: 0.658732\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|██████████| 4043/4043 [00:17<00:00, 236.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 1] average loss: 0.577612\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|██████████| 4043/4043 [00:16<00:00, 241.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 2] average loss: 0.487293\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████| 4043/4043 [00:16<00:00, 239.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 3] average loss: 0.447478\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████| 4043/4043 [00:17<00:00, 236.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 4] average loss: 0.422915\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5: 100%|██████████| 4043/4043 [00:16<00:00, 241.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 5] average loss: 0.405007\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6: 100%|██████████| 4043/4043 [00:16<00:00, 243.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 6] average loss: 0.392602\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7: 100%|██████████| 4043/4043 [00:16<00:00, 242.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 7] average loss: 0.383038\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|██████████| 4007/4007 [00:04<00:00, 923.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc: 0.716779,auc: 0.737908,rmse: 0.446514, mae: 0.327794\n"
     ]
    }
   ],
   "source": [
    "def transform(user, item, Q, score, batch_size, shuffle=True):\n",
    "    \"\"\"Wrap response logs in a batched DataLoader.\n",
    "\n",
    "    Each sample is (user index, item index, Q row of that item, score).\n",
    "\n",
    "    Args:\n",
    "        user: 1-based user ids, one per response.\n",
    "        item: 1-based item ids, one per response.\n",
    "        Q: indexed by 0-based item id along its first axis; presumably a\n",
    "            torch tensor, since .float() is called on the selected rows.\n",
    "        score: response scores aligned with user and item.\n",
    "        batch_size: number of samples per batch.\n",
    "        shuffle: whether the loader reshuffles the samples every epoch.\n",
    "            Defaults to True, which was the previous fixed behaviour.\n",
    "\n",
    "    Returns:\n",
    "        A DataLoader over the assembled TensorDataset.\n",
    "    \"\"\"\n",
    "    user_idx = torch.tensor(user, dtype=torch.int64) - 1  # (1, user_n) to (0, user_n-1)\n",
    "    item_idx = torch.tensor(item, dtype=torch.int64) - 1  # (1, item_n) to (0, item_n-1)\n",
    "    data_set = TensorDataset(\n",
    "        user_idx,\n",
    "        item_idx,\n",
    "        Q[item_idx].float(),  # reuse item_idx instead of converting the ids a second time\n",
    "        torch.tensor(score, dtype=torch.float32),\n",
    "    )\n",
    "    return DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "\n",
    "# Build one loader per split; reset_index gives each frame a fresh positional index.\n",
    "train_frame = train_data.reset_index()\n",
    "test_frame = test_data.reset_index()\n",
    "train_set = transform(\n",
    "    train_frame[\"user_id\"], train_frame[\"item_id\"], Q, train_frame[\"score\"], batch_size\n",
    ")\n",
    "test_set = transform(\n",
    "    test_frame[\"user_id\"], test_frame[\"item_id\"], Q, test_frame[\"score\"], batch_size\n",
    ")\n",
    "\n",
    "# Let INFO-level records through the root logger.\n",
    "root_logger = logging.getLogger()\n",
    "root_logger.setLevel(logging.INFO)\n",
    "\n",
    "# Train the model, then score it on the held-out split.\n",
    "cdm = NCDM(knowledge_n, item_n, user_n)\n",
    "cdm.train(train_data=train_set, epoch=epochs, device=device)\n",
    "\n",
    "accuracy, auc, rmse, mae = cdm.eval(test_set)\n",
    "print(\"acc: %.6f,auc: %.6f,rmse: %.6f, mae: %.6f\" % (accuracy, auc, rmse, mae))\n",
    "\n",
    "# Sigmoid of the student embedding for every student (ids shifted by one before lookup).\n",
    "all_stu_idx = dataSet.total_stu_list - 1\n",
    "student_embedding = cdm.ncdm_net.student_emb(torch.LongTensor(all_stu_idx))\n",
    "cognitive_state = torch.sigmoid(student_embedding)\n",
    "\n",
    "\n",
    "def save_param(save_dir, name, param):\n",
    "    \"\"\"Write a tensor to the file save_dir + name as comma-separated text.\n",
    "\n",
    "    Values are written with two decimals. The parent directory of the target\n",
    "    file is created first when it is missing, so np.savetxt does not fail on\n",
    "    a fresh checkout.\n",
    "\n",
    "    Args:\n",
    "        save_dir: path prefix, concatenated with name as-is (no separator added).\n",
    "        name: file name appended to save_dir.\n",
    "        param: tensor to save; moved to the CPU and detached before conversion.\n",
    "    \"\"\"\n",
    "    import os  # local import keeps this cell runnable without touching the import cell\n",
    "\n",
    "    target = save_dir + name\n",
    "    parent = os.path.dirname(target)\n",
    "    if parent:  # empty when the target is in the current directory\n",
    "        os.makedirs(parent, exist_ok=True)\n",
    "    np.savetxt(target, param.cpu().detach().numpy(), fmt='%.2f', delimiter=',')\n",
    "\n",
    "\n",
    "# Persist the cognitive states under the output folder of the selected dataset.\n",
    "output_dir = './output/' + save_list[dataSet_idx]\n",
    "save_param(output_dir, 'cognitive_state.csv', cognitive_state)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "527a93331b4b1a8345148922acc34427fb7591433d63b66d32040b6fbbc6d593"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
